# Copyright (c) Open-MMLab. All rights reserved.
#
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import chain

from torch.nn.parallel import DataParallel

from .scatter_gather import scatter_kwargs


class MMDataParallel(DataParallel):
    """The DataParallel module that supports DataContainer.

    MMDataParallel has two main differences from PyTorch's ``DataParallel``:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
    - It implements two extra APIs, ``train_step()`` and ``val_step()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        device_ids (list[int]): Device IDs of modules to be scattered to.
            Defaults to None when GPU is not available.
        output_device (str | int): Device ID for output. Defaults to None.
        dim (int): Dimension used to scatter the data. Defaults to 0.
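
    Example:
        A minimal usage sketch; ``model`` and ``data_batch`` below are
        placeholders for an ``nn.Module`` and a batch produced by a
        :class:`DataContainer`-aware collate function:

        >>> # Wrap the model for single-GPU inference or training.
        >>> dp_model = MMDataParallel(model.cuda(0), device_ids=[0])
        >>> # DataContainer fields are unwrapped during scatter before the
        >>> # wrapped module is called.
        >>> results = dp_model(**data_batch)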
    """

    def __init__(self, *args, dim=0, **kwargs):
        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
        self.dim = dim

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in CPU inference, where the data in
        :class:`DataContainer` will still be gathered.
        """
        if not self.device_ids:
            # Scatter with device_ids=[-1] so the module can still gather
            # and convert data containers, just as in GPU inference.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
        else:
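            # With GPUs available, fall back to the standard DataParallel
            # forward, which scatters inputs via ``self.scatter`` below.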
            return super().forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
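        # scatter_kwargs handles DataContainer inputs in addition to plain
        # tensors; device_ids == [-1] selects the CPU path used in forward().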
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    # def train_step(self, *inputs, **kwargs):
    #     if not self.device_ids:
    #         # Scatter with device_ids=[-1] so the module can still gather
    #         # and convert data containers, just as in GPU inference.
    #         inputs, kwargs = self.scatter(inputs, kwargs, [-1])
    #         return self.module.train_step(*inputs[0], **kwargs[0])
    #
    #     assert len(self.device_ids) == 1, \
    #         ('MMDataParallel only supports single GPU training. If you need to'
    #          ' train with multiple GPUs, please use MMDistributedDataParallel'
    #          ' instead.')
    #
    #     for t in chain(self.module.parameters(), self.module.buffers()):
    #         if t.device != self.src_device_obj:
    #             raise RuntimeError(
    #                 'module must have its parameters and buffers '
    #                 f'on device {self.src_device_obj} (device_ids[0]) but '
    #                 f'found one of them on device: {t.device}')
    #
    #     inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    #     return self.module.train_step(*inputs[0], **kwargs[0])
    #
    # def val_step(self, *inputs, **kwargs):
    #     if not self.device_ids:
    #         # Scatter with device_ids=[-1] so the module can still gather
    #         # and convert data containers, just as in GPU inference.
    #         inputs, kwargs = self.scatter(inputs, kwargs, [-1])
    #         return self.module.val_step(*inputs[0], **kwargs[0])
    #
    #     assert len(self.device_ids) == 1, \
    #         ('MMDataParallel only supports single GPU training. If you need to'
    #          ' train with multiple GPUs, please use MMDistributedDataParallel'
    #          ' instead.')
    #
    #     for t in chain(self.module.parameters(), self.module.buffers()):
    #         if t.device != self.src_device_obj:
    #             raise RuntimeError(
    #                 'module must have its parameters and buffers '
    #                 f'on device {self.src_device_obj} (device_ids[0]) but '
    #                 f'found one of them on device: {t.device}')
    #
    #     inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    #     return self.module.val_step(*inputs[0], **kwargs[0])

